{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will learn how to build, compose, and transform Iteration/Expression Trees (IETs).\n",
    "\n",
    "# Part II - Bottom Up\n",
    "\n",
    "`Dimensions` are the building blocks of both `Iterations` and `Expressions`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'i': i, 'j': j, 'k': k, 't0': t0, 't1': t1}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from devito import SpaceDimension, TimeDimension\n",
    "\n",
    "dims = {'i': SpaceDimension(name='i'),\n",
    "        'j': SpaceDimension(name='j'),\n",
    "        'k': SpaceDimension(name='k'),\n",
    "        't0': TimeDimension(name='t0'),\n",
    "        't1': TimeDimension(name='t1')}\n",
    "\n",
    "dims"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Elements such as `Scalars`, `Constants` and `Functions` are used to build SymPy equations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': a, 'b': b, 'c': c[i], 'd': d[j, k], 'e': e[t0, t1, i], 'f': f[t, x, y]}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from devito import Grid, Constant, Function, TimeFunction\n",
    "from devito.types import Array, Scalar\n",
    "\n",
    "grid = Grid(shape=(10, 10))\n",
    "symbs = {'a': Scalar(name='a'),\n",
    "         'b': Constant(name='b'),\n",
    "         'c': Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify(),\n",
    "         'd': Array(name='d', \n",
    "                    shape=(3,3), \n",
    "                    dimensions=(dims['j'],dims['k'])).indexify(),\n",
    "         'e': Function(name='e', \n",
    "                       shape=(3,3,3), \n",
    "                       dimensions=(dims['t0'],dims['t1'],dims['i'])).indexify(),\n",
    "         'f': TimeFunction(name='f', grid=grid).indexify()}\n",
    "symbs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. What, when and how metadata are attached is here irrelevant. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Expression a = b + c[i] + 5.0>\n",
      "<Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
      "<Expression a = 4*a*b>\n",
      "<Expression a = 8.0*a + 6.0/b>\n"
     ]
    }
   ],
   "source": [
    "from devito.ir.iet import Expression\n",
    "from devito.ir.equations import DummyEq\n",
    "from devito.tools import pprint\n",
    "\n",
    "def get_exprs(a, b, c, d, e, f):\n",
    "    return [Expression(DummyEq(a, b + c + 5.)),\n",
    "            Expression(DummyEq(d, e - f)),\n",
    "            Expression(DummyEq(a, 4 * (b * a))),\n",
    "            Expression(DummyEq(a, (6. / b) + (8. * a)))]\n",
    "\n",
    "exprs = get_exprs(symbs['a'],\n",
    "                  symbs['b'],\n",
    "                  symbs['c'],\n",
    "                  symbs['d'],\n",
    "                  symbs['e'],\n",
    "                  symbs['f'])\n",
    "\n",
    "pprint(exprs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An `Iteration` typically wraps one or more `Expression`s. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from devito.ir.iet import Iteration\n",
    "\n",
    "def get_iters(dims):\n",
    "    return [lambda ex: Iteration(ex, dims['i'], (0, 3, 1)),\n",
    "            lambda ex: Iteration(ex, dims['j'], (0, 5, 1)),\n",
    "            lambda ex: Iteration(ex, dims['k'], (0, 7, 1)),\n",
    "            lambda ex: Iteration(ex, dims['t0'], (0, 4, 1)),\n",
    "            lambda ex: Iteration(ex, dims['t1'], (0, 4, 1))]\n",
    "\n",
    "iters = get_iters(dims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we can see how blocks of `Iterations` over `Expressions` can be used to build loop nests. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Iteration i::i::(0, 3, 1)>\n",
      "  <Iteration j::j::(0, 5, 1)>\n",
      "    <Iteration k::k::(0, 7, 1)>\n",
      "      <Expression a = b + c[i] + 5.0>\n",
      "\n",
      "\n",
      "<Iteration i::i::(0, 3, 1)>\n",
      "  <Expression a = b + c[i] + 5.0>\n",
      "  <Iteration j::j::(0, 5, 1)>\n",
      "    <Iteration k::k::(0, 7, 1)>\n",
      "      <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
      "\n",
      "\n",
      "<Iteration i::i::(0, 3, 1)>\n",
      "  <Iteration t0::t0::(0, 4, 1)>\n",
      "    <Expression a = b + c[i] + 5.0>\n",
      "  <Iteration j::j::(0, 5, 1)>\n",
      "    <Iteration k::k::(0, 7, 1)>\n",
      "      <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
      "      <Expression a = 4*a*b>\n",
      "  <Iteration t1::t1::(0, 4, 1)>\n",
      "    <Expression a = 8.0*a + 6.0/b>\n"
     ]
    }
   ],
   "source": [
    "def get_block1(exprs, iters):\n",
    "    # Perfect loop nest:\n",
    "    # for i\n",
    "    #   for j\n",
    "    #     for k\n",
    "    #       expr0\n",
    "    return iters[0](iters[1](iters[2](exprs[0])))\n",
    "    \n",
    "def get_block2(exprs, iters):\n",
    "    # Non-perfect simple loop nest:\n",
    "    # for i\n",
    "    #   expr0\n",
    "    #   for j\n",
    "    #     for k\n",
    "    #       expr1\n",
    "    return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])\n",
    "\n",
    "def get_block3(exprs, iters):\n",
    "    # Non-perfect non-trivial loop nest:\n",
    "    # for i\n",
    "    #   for s\n",
    "    #     expr0\n",
    "    #   for j\n",
    "    #     for k\n",
    "    #       expr1\n",
    "    #       expr2\n",
    "    #   for p\n",
    "    #     expr3\n",
    "    return iters[0]([iters[3](exprs[0]),\n",
    "                     iters[1](iters[2]([exprs[1], exprs[2]])),\n",
    "                     iters[4](exprs[3])])\n",
    "\n",
    "block1 = get_block1(exprs, iters)\n",
    "block2 = get_block2(exprs, iters)\n",
    "block3 = get_block3(exprs, iters)\n",
    "\n",
    "pprint(block1), print('\\n')\n",
    "pprint(block2), print('\\n')\n",
    "pprint(block3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And, finally, we can build `Callable` _kernels_ that will be used to generate C code. Note that `Operator` is a subclass of `Callable`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kernel no.1:\n",
      "void foo()\n",
      "{\n",
      "  for (int i = 0; i <= 3; i += 1)\n",
      "  {\n",
      "    for (int j = 0; j <= 5; j += 1)\n",
      "    {\n",
      "      for (int k = 0; k <= 7; k += 1)\n",
      "      {\n",
      "        a = b + c[i] + 5.0F;\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n",
      "kernel no.2:\n",
      "void foo()\n",
      "{\n",
      "  for (int i = 0; i <= 3; i += 1)\n",
      "  {\n",
      "    a = b + c[i] + 5.0F;\n",
      "    for (int j = 0; j <= 5; j += 1)\n",
      "    {\n",
      "      for (int k = 0; k <= 7; k += 1)\n",
      "      {\n",
      "        d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n",
      "kernel no.3:\n",
      "void foo()\n",
      "{\n",
      "  for (int i = 0; i <= 3; i += 1)\n",
      "  {\n",
      "    for (int t0 = 0; t0 <= 4; t0 += 1)\n",
      "    {\n",
      "      a = b + c[i] + 5.0F;\n",
      "    }\n",
      "    for (int j = 0; j <= 5; j += 1)\n",
      "    {\n",
      "      for (int k = 0; k <= 7; k += 1)\n",
      "      {\n",
      "        d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
      "        a = 4*a*b;\n",
      "      }\n",
      "    }\n",
      "    for (int t1 = 0; t1 <= 4; t1 += 1)\n",
      "    {\n",
      "      a = 8.0F*a + 6.0F/b;\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from devito.ir.iet import Callable\n",
    "\n",
    "kernels = [Callable('foo', block1, 'void', ()),\n",
    "           Callable('foo', block2, 'void', ()),\n",
    "           Callable('foo', block3, 'void', ())]\n",
    "\n",
    "print('kernel no.1:\\n' + str(kernels[0].ccode) + '\\n')\n",
    "print('kernel no.2:\\n' + str(kernels[1].ccode) + '\\n')\n",
    "print('kernel no.3:\\n' + str(kernels[2].ccode) + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An IET is immutable. It can be \"transformed\" by replacing or dropping some of its inner nodes, but what this actually means is that a new IET is created. IETs are transformed by `Transformer` visitors. A `Transformer` takes in input a dictionary encoding replacement rules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "void foo()\n",
      "{\n",
      "  for (int i = 0; i <= 3; i += 1)\n",
      "  {\n",
      "    a = b + c[i] + 5.0F;\n",
      "    for (int j = 0; j <= 5; j += 1)\n",
      "    {\n",
      "      for (int k = 0; k <= 7; k += 1)\n",
      "      {\n",
      "        d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "from devito.ir.iet import Transformer\n",
    "\n",
    "# Replaces a Function's body with another\n",
    "transformer = Transformer({block1: block2})\n",
    "kernel_alt = transformer.visit(kernels[0])\n",
    "print(kernel_alt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specific `Expression`s within the loop nest can also be substituted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for (int i = 0; i <= 3; i += 1)\n",
      "{\n",
      "  for (int j = 0; j <= 5; j += 1)\n",
      "  {\n",
      "    for (int k = 0; k <= 7; k += 1)\n",
      "    {\n",
      "      d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Replaces an expression with another\n",
    "transformer = Transformer({exprs[0]: exprs[1]})\n",
    "newblock = transformer.visit(block1)\n",
    "newcode = str(newblock.ccode)\n",
    "print(newcode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for (int i = 0; i <= 3; i += 1)\n",
      "{\n",
      "  a = b + c[i] + 5.0F;\n",
      "  for (int j = 0; j <= 5; j += 1)\n",
      "  {\n",
      "    for (int k = 0; k <= 7; k += 1)\n",
      "    {\n",
      "      // Replaced expression\n",
      "      {\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "from devito.ir.iet import Block\n",
    "import cgen as c\n",
    "\n",
    "# Creates a replacer for replacing an expression\n",
    "line1 = '// Replaced expression'\n",
    "replacer = Block(c.Line(line1))\n",
    "transformer = Transformer({exprs[1]: replacer})\n",
    "newblock = transformer.visit(block2)\n",
    "newcode = str(newblock.ccode)\n",
    "print(newcode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for (int i = 0; i <= 3; i += 1)\n",
      "{\n",
      "  for (int j = 0; j <= 5; j += 1)\n",
      "  {\n",
      "    for (int k = 0; k <= 7; k += 1)\n",
      "    {\n",
      "      // This is the opening comment\n",
      "      {\n",
      "        a = b + c[i] + 5.0F;\n",
      "      }\n",
      "      // This is the closing comment\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Wraps an expression in comments\n",
    "line1 = '// This is the opening comment'\n",
    "line2 = '// This is the closing comment'\n",
    "wrapper = lambda n: Block(c.Line(line1), n, c.Line(line2))\n",
    "transformer = Transformer({exprs[0]: wrapper(exprs[0])})\n",
    "newblock = transformer.visit(block1)\n",
    "newcode = str(newblock.ccode)\n",
    "print(newcode)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
